import argparse
import logging
import sys
import requests
from typing import Dict, Any

from gpustack import __version__, __git_commit__
from gpustack.cmd.start import load_config_from_yaml
from gpustack.config.config import Config
from gpustack.utils.envs import get_gpustack_env
from gpustack.utils.config import (
    WHITELIST_CONFIG_FIELDS,
    coerce_value_by_field,
    filter_whitelisted_yaml_config,
)
from gpustack.logging import setup_logging


logger = logging.getLogger(__name__)


class OptionalBoolAction(argparse.Action):
    """Zero-argument argparse action that stores ``True`` when the flag is given.

    When the flag is absent the destination keeps the ``default`` supplied at
    registration time (``None`` if no default was supplied).
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # The flag never consumes a value, so an explicit ``nargs`` is rejected.
        if nargs is not None:
            raise ValueError("nargs not allowed")
        declared_default = kwargs.get("default")
        super().__init__(option_strings, dest, nargs=0, **kwargs)
        self.default = declared_default

    def __call__(self, parser, namespace, values, option_string=None):
        # Presence of the option is the only signal; ``values`` is ignored.
        setattr(namespace, self.dest, True)


def setup_reload_config_cmd(subparsers: argparse._SubParsersAction):
    """Register the ``reload-config`` subcommand and its options on *subparsers*.

    The subcommand dispatches to ``run`` through the parser's ``func`` default.
    Defaults for ``--api-key``, ``--server-port`` and ``--worker-port`` are read
    through ``get_gpustack_env`` at registration time.
    """
    set_help = (
        "Set a single configuration value: --set key=value (key in hyphen-case). "
        "Only whitelisted fields are applied. "
        "Values are coerced by target field type. "
        "Lists accept comma-separated strings "
        "(e.g., --set allow-origins=https://a.com,https://b.com). "
        "Dicts require JSON string "
        "(e.g., --set system-reserved='{\"ram\":2,\"vram\":1}'). "
        "Invalid JSON will cause an error and exit."
    )
    file_help = (
        "Load configuration from YAML file: --file /path/to/gpustack_config.yaml"
    )
    list_help = (
        "List whitelisted fields that can be updated, can't use with --set or --file."
    )

    reload_parser: argparse.ArgumentParser = subparsers.add_parser(
        "reload-config",
        help="Reload GPUStack configuration.",
        description=("Reload GPUStack configuration via --set, --file, or --list."),
    )

    # Registration order is kept because it determines the order in --help.
    reload_parser.add_argument("--set", action="append", help=set_help)
    reload_parser.add_argument("--file", type=str, help=file_help)
    reload_parser.add_argument(
        "--list", action=OptionalBoolAction, help=list_help, default=False
    )
    reload_parser.add_argument(
        "--api-key",
        type=str,
        help="API Key to authenticate as admin.",
        default=get_gpustack_env("API_KEY"),
    )

    # The two port options differ only in flag, help text and environment name.
    port_options = (
        ("--server-port", "Port of the GPUStack API server to target.", "API_PORT"),
        ("--worker-port", "Port of the GPUStack worker to target.", "WORKER_PORT"),
    )
    for flag, help_text, env_name in port_options:
        reload_parser.add_argument(
            flag,
            type=int,
            help=help_text,
            default=get_gpustack_env(env_name),
        )

    reload_parser.set_defaults(func=run)


def run(args):
    """Entry point of ``reload-config``: build the payload and push it at runtime.

    When ``handle_list_mode`` reports that it handled the request, nothing else
    is done. Any exception raised along the way is logged and converted into
    exit status 1.
    """
    try:
        logger.info("Starting configuration reload...")
        logger.info(f"GPUStack version: {__version__} ({__git_commit__})")

        list_handled = handle_list_mode(args)
        if list_handled:
            return

        cfg = parse_args_with_filter(args, {})

        # Only whitelisted fields that actually carry a value are sent.
        payload = {
            name: current
            for name in WHITELIST_CONFIG_FIELDS
            if (current := getattr(cfg, name, None)) is not None
        }

        setup_logging(cfg.debug)
        apply_runtime_updates(payload, args)
        display_config_summary(cfg)
    except Exception as e:
        logger.error(f"Failed to reload configuration: {e}")
        sys.exit(1)


def display_config_summary(cfg):
    """Log each whitelisted field of *cfg* whose value is not ``None``."""
    logger.info("=== Configuration Reload Summary ===")

    # Lazily pair each whitelisted name with its value; missing attributes
    # read as None and are skipped just like unset ones.
    name_value_pairs = (
        (name, getattr(cfg, name, None)) for name in WHITELIST_CONFIG_FIELDS
    )
    for name, current in name_value_pairs:
        if current is None:
            continue
        logger.info(f"- reload: {name} = {current}")

    logger.info("Configuration successfully reloaded.")
    logger.info("=====================================")


def parse_args_with_filter(args: argparse.Namespace, filtered_changes: Dict[str, Any]):
    """
    Build a ``Config`` from the whitelisted reload inputs.

    Sources are merged in increasing precedence:

    1. the YAML file given with ``--file``, passed through
       ``filter_whitelisted_yaml_config``;
    2. every ``--set key=value`` item whose key (hyphens mapped to
       underscores) is in ``WHITELIST_CONFIG_FIELDS``, with the value coerced
       by ``coerce_value_by_field``;
    3. ``filtered_changes``.

    Args:
        args: Parsed command-line namespace; ``file`` and ``set`` are read
            when present.
        filtered_changes: Already-filtered overrides, applied last.

    Returns:
        A ``Config`` constructed only from the merged data.

    Raises:
        ValueError: If a ``--set`` item does not contain ``=``.
    """

    config_data: Dict[str, Any] = {}

    # Handle config file if provided
    config_file = getattr(args, "file", None)
    if config_file:
        yaml_data = load_config_from_yaml(config_file)
        config_data.update(filter_whitelisted_yaml_config(yaml_data or {}))

    for item in getattr(args, "set", None) or []:
        if "=" not in item:
            raise ValueError(f"Invalid --set value: {item}. Use key=value")
        raw_key, raw_value = item.split("=", 1)
        key = raw_key.replace("-", "_")
        if key not in WHITELIST_CONFIG_FIELDS:
            # Tell the user the setting was dropped instead of ignoring it
            # silently; the reload itself still proceeds.
            logger.warning(
                f"Ignoring --set {raw_key}: not a whitelisted configuration field"
            )
            continue
        config_data[key] = coerce_value_by_field(key, raw_value)

    # Apply filtered command line changes (these override config file)
    config_data.update(filtered_changes)

    # Create config with filtered data - only use the filtered config data
    # Don't call set_common_options/set_server_options/set_worker_options
    # as they would re-apply all command line arguments including blocked ones
    return Config(**config_data)


def apply_runtime_updates(
    payload: Dict[str, Any],
    args: argparse.Namespace,
    timeout: float = 10,
):
    """
    PUT *payload* to the local server and worker ``/debug/config`` endpoints.

    This is best effort: a non-200 answer or any exception for one endpoint is
    logged as a warning and the next endpoint is still tried.

    Args:
        payload: Configuration values sent as the JSON request body.
        args: Parsed namespace; ``api_key``, ``server_port`` and
            ``worker_port`` are read when present. Ports fall back to
            8080 (server) and 10150 (worker) when missing or falsy.
        timeout: Seconds passed to ``requests.put`` so that an unresponsive
            endpoint cannot block the command indefinitely.
    """
    api_key = getattr(args, "api_key", None)
    # Default to None so a namespace without these attributes still works,
    # matching how handle_list_mode reads them.
    server_port = getattr(args, "server_port", None) or 8080
    worker_port = getattr(args, "worker_port", None) or 10150
    urls = [
        f"http://127.0.0.1:{server_port}/debug/config",
        f"http://127.0.0.1:{worker_port}/debug/config",
    ]
    # The headers are the same for every endpoint, so build them once.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    for url in urls:
        try:
            resp = requests.put(url, json=payload, headers=headers, timeout=timeout)
            if resp.status_code == 200:
                logger.info(f"Applied runtime config via {url}")
            else:
                logger.warning(f"Failed to apply config via {url}: {resp.status_code}")
        except Exception as e:
            # Deliberately broad: one failing endpoint must not abort the other.
            logger.warning(f"Failed to apply config via {url}: {e}")


def list_runtime_values(
    api_key: str | None = None,
    server_port: int | None = None,
    worker_port: int | None = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Fetch the current config from the local server and worker debug endpoints.

    Best effort: a scope is left out of the result when its request raises or
    answers with a status other than 200. Ports fall back to 8080 (server) and
    10150 (worker) when not given or falsy.

    Returns:
        Mapping of scope name (``"server"`` / ``"worker"``) to the decoded JSON
        body of that endpoint.
    """
    auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    targets = (
        ("server", server_port or 8080),
        ("worker", worker_port or 10150),
    )

    collected: Dict[str, Dict[str, Any]] = {}
    for scope, port in targets:
        url = f"http://127.0.0.1:{port}/debug/config"
        try:
            reply = requests.get(url, timeout=2, headers=auth_headers)
            if reply.status_code == 200:
                collected[scope] = reply.json()
        except Exception:
            # Unreachable or misbehaving endpoints are skipped on purpose.
            pass
    return collected


def handle_list_mode(args) -> bool:
    """
    Print the whitelisted fields and their current runtime values for ``--list``.

    Returns:
        ``False`` when ``args.list`` is missing or falsy (nothing is printed),
        ``True`` once the listing has been printed.
    """
    if not getattr(args, "list", False):
        return False

    # Sort once and pair each field with its hyphen-case display name.
    ordered_fields = sorted(WHITELIST_CONFIG_FIELDS)
    display_names = {name: name.replace('_', '-') for name in ordered_fields}

    print("Whitelisted fields:")
    for name in ordered_fields:
        print(f"- {display_names[name]}")

    runtime_values = list_runtime_values(
        api_key=getattr(args, "api_key", None),
        server_port=getattr(args, "server_port", None),
        worker_port=getattr(args, "worker_port", None),
    )
    if not runtime_values:
        return True

    print("Current config values:")
    for scope, conf in runtime_values.items():
        for name in ordered_fields:
            if name not in conf or conf[name] is None:
                continue
            print(f"- {scope}: {display_names[name]} = {conf[name]}")
    return True
